import histomicstk as htk

import numpy as np
import scipy as sp

import skimage.io
import skimage.measure
import skimage.color

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
%matplotlib inline

# Default plot configuration: 10x10-inch figures, grayscale colormap for
# single-channel images, and a shared title font size used by the larger
# side-by-side figures further down.
plt.rcParams['figure.figsize'] = 10, 10
plt.rcParams['image.cmap'] = 'gray'
titlesize = 24

# Load each input tile, keep only the first three channels (drops an alpha
# channel if present), and show it in its own figure. The explicit
# `plt.figure()` calls matter when this runs as a plain script: without them
# the second image is drawn into the first image's axes.
longitudinal_image_file = 'longitudinal.png'
long_im_input = skimage.io.imread(longitudinal_image_file)[:, :, :3]
plt.figure()
plt.imshow(long_im_input)
_ = plt.title('Longitudinal Cellular Orientation', fontsize=16)

transverse_image_file = 'transverse.png'
trans_im_input = skimage.io.imread(transverse_image_file)[:, :, :3]
plt.figure()
plt.imshow(trans_im_input)
_ = plt.title('Transverse Cellular Orientation', fontsize=16)
# Reference tile whose Lab-space color statistics both inputs are normalized to.
# NOTE(review): the original trailing comment read "L1.png" -- presumably the
# name of an earlier reference file; confirm which reference is intended.
ref_image_file = '6070-7712.png'
im_reference = skimage.io.imread(ref_image_file)[:, :, :3]

# Mean and standard deviation of the reference image in Lab space.
mean_ref, std_ref = htk.preprocessing.color_conversion.lab_mean_std(im_reference)

# Normalize each input toward the reference statistics (Reinhard method).
_reinhard = htk.preprocessing.color_normalization.reinhard
long_im_nmzd = _reinhard(long_im_input, mean_ref, std_ref)
trans_im_nmzd = _reinhard(trans_im_input, mean_ref, std_ref)
# For each input, show the reference image (left) next to the color-normalized
# input (right), one figure per input.
_normalized_panels = (
    (long_im_nmzd, 'Normalized Longitudinal Input Image'),
    (trans_im_nmzd, 'Normalized Transverse Input Image'),
)
for _im_nmzd, _nmzd_title in _normalized_panels:
    plt.figure(figsize=(20, 10))

    plt.subplot(1, 2, 1)
    plt.imshow(im_reference)
    _ = plt.title('Reference Image', fontsize=titlesize)

    plt.subplot(1, 2, 2)
    plt.imshow(_im_nmzd)
    _ = plt.title(_nmzd_title, fontsize=titlesize)
# Stain name -> 3-component color vector used to build the stain matrix.
# 'null' is an all-zero placeholder for an unused third stain.
stainColorMap = {
    'hematoxylin': [0.65, 0.70, 0.29],
    'eosin': [0.07, 0.99, 0.11],
    'dab': [0.27, 0.57, 0.78],
    'null': [0.0, 0.0, 0.0],
}

# Stains present in the input, in deconvolution channel order.
stain_1 = 'hematoxylin'  # channel 0: nuclei stain
stain_2 = 'eosin'        # channel 1: cytoplasm stain
stain_3 = 'null'         # channel 2: use 'null' when the input has only two stains

# Stain matrix W: one column per stain vector, columns ordered as above.
W = np.array([stainColorMap[name] for name in (stain_1, stain_2, stain_3)]).T
def _show_stain_channels(im_stains):
    """Show deconvolved channels 0 and 1 side by side, titled with their stain names.

    Returns the title artist of the last panel.
    """
    plt.figure(figsize=(20, 10))
    for channel, stain_name in enumerate((stain_1, stain_2)):
        plt.subplot(1, 2, channel + 1)
        plt.imshow(im_stains[:, :, channel])
        title = plt.title(stain_name, fontsize=titlesize)
    return title


_deconvolve = htk.preprocessing.color_deconvolution.color_deconvolution

# Separate the normalized longitudinal image into per-stain channels and show them.
long_im_stains = _deconvolve(long_im_nmzd, W).Stains
_ = _show_stain_channels(long_im_stains)

# Same for the transverse image.
trans_im_stains = _deconvolve(trans_im_nmzd, W).Stains
_ = _show_stain_channels(trans_im_stains)
# Channel 0 of the deconvolved images is the hematoxylin (nuclei) stain.
long_im_nuclei_stain = long_im_stains[:, :, 0]
trans_im_nuclei_stain = trans_im_stains[:, :, 0]

# Foreground = pixels whose hematoxylin-channel value is below the threshold
# (presumably the strongly stained pixels -- confirm against the stain
# channel's value range). Enclosed holes are filled so each nucleus is a
# solid region.
# `binary_fill_holes` is used from `scipy.ndimage` directly: the
# `scipy.ndimage.morphology` namespace is deprecated.
foreground_threshold = 60

long_im_fgnd_mask = sp.ndimage.binary_fill_holes(
    long_im_nuclei_stain < foreground_threshold)

trans_im_fgnd_mask = sp.ndimage.binary_fill_holes(
    trans_im_nuclei_stain < foreground_threshold)

# Multi-scale blob-enhancement filter on the hematoxylin channel, restricted to
# the foreground mask. `cdog` is presumably a constrained difference-of-Gaussians
# approximation of a multi-scale LoG -- confirm in the HistomicsTK docs. Judging
# by the variable names, it returns the per-pixel maximal response and the sigma
# at which that maximum occurred.
min_radius = 10
max_radius = 15

# NOTE(review): for a 2-D blob of radius r the LoG response peaks at
# sigma = r / sqrt(2), but this uses r * sqrt(2), i.e. 2x larger. It is kept
# unchanged here; confirm intent before altering, since it changes the
# segmentation results.
dog_sigma_min = min_radius * np.sqrt(2)
dog_sigma_max = max_radius * np.sqrt(2)

long_im_log_max, long_im_sigma_max = htk.filters.shape.cdog(
    long_im_nuclei_stain, long_im_fgnd_mask,
    sigma_min=dog_sigma_min, sigma_max=dog_sigma_max)

trans_im_log_max, trans_im_sigma_max = htk.filters.shape.cdog(
    trans_im_nuclei_stain, trans_im_fgnd_mask,
    sigma_min=dog_sigma_min, sigma_max=dog_sigma_max)

# Detect and label nuclei by clustering foreground pixels around local maxima
# of the filter response. The radius is presumably the neighborhood size used
# for local-maximum detection.
local_max_search_radius = 10
_max_clustering = htk.segmentation.nuclear.max_clustering

# Only the label mask is kept for the longitudinal image. `seeds` and `maxima`
# end up holding the transverse image's values, exactly as before.
long_im_nuclei_seg_mask = _max_clustering(
    long_im_log_max, long_im_fgnd_mask, local_max_search_radius)[0]

trans_im_nuclei_seg_mask, seeds, maxima = _max_clustering(
    trans_im_log_max, trans_im_fgnd_mask, local_max_search_radius)

# Filter out segmented objects smaller than `min_nucleus_area` pixels.
# Cast with the builtin `int`: `np.int` was only a deprecated alias for it and
# was removed in NumPy 1.24 (AttributeError). The resulting dtype is identical.
min_nucleus_area = 80

long_im_nuclei_seg_mask = htk.segmentation.label.area_open(
    long_im_nuclei_seg_mask, min_nucleus_area).astype(int)

trans_im_nuclei_seg_mask = htk.segmentation.label.area_open(
    trans_im_nuclei_seg_mask, min_nucleus_area).astype(int)

# Per-object region properties: one entry per labeled nucleus.
LongObjProps = skimage.measure.regionprops(long_im_nuclei_seg_mask)
TransObjProps = skimage.measure.regionprops(trans_im_nuclei_seg_mask)

print('Number of nuclei in longitudinal image = ', len(LongObjProps))
print('Number of nuclei in transverse image = ', len(TransObjProps))
# Recorded notebook output:
# Number of nuclei in longitudinal image =  43
# Number of nuclei in transverse image =  122
def _plot_nuclei_overlays(im_input, im_seg_mask, obj_props):
    """Plot the segmentation overlay and per-nucleus bounding boxes for one image.

    Left panel: the label mask colorized over ``im_input`` (background label 0).
    Right panel: ``im_input`` with a green ``+`` at each nucleus centroid and a
    green rectangle around each nucleus's bounding box.

    Both panels are drawn with row 0 at the bottom (``origin='lower'`` on the
    left, an increasing y-limit on the right), so they line up with each other.

    Args:
        im_input: image the mask was computed from.
        im_seg_mask: integer label image (0 = background).
        obj_props: ``skimage.measure.regionprops(im_seg_mask)``.

    Returns:
        list of dict: one rectangle annotation per nucleus with keys ``type``,
        ``center`` ([x, y, 0]), ``width`` and ``height`` in pixels.
    """
    plt.figure(figsize=(20, 10))

    plt.subplot(1, 2, 1)
    plt.imshow(skimage.color.label2rgb(im_seg_mask, im_input, bg_label=0),
               origin='lower')
    plt.title('Nuclei segmentation mask overlay', fontsize=titlesize)

    plt.subplot(1, 2, 2)
    plt.imshow(im_input)
    # imshow centers pixel (row, col) at (x=col, y=row), so the full image
    # spans [-0.5, W - 0.5] x [-0.5, H - 0.5]. The increasing y-limit puts
    # row 0 at the bottom, matching the left panel.
    plt.xlim([-0.5, im_input.shape[1] - 0.5])
    plt.ylim([-0.5, im_input.shape[0] - 0.5])
    plt.title('Nuclei bounding boxes', fontsize=titlesize)

    bboxes = []
    for props in obj_props:
        # regionprops bbox is (min_row, min_col, max_row, max_col) with the max
        # bounds exclusive, so the extent is max - min (no +1).
        min_row, min_col, max_row, max_col = props.bbox
        width = max_col - min_col
        height = max_row - min_row
        centroid_row, centroid_col = props.centroid

        bboxes.append({
            "type": "rectangle",
            # Center of the box itself, not the centroid (the centroid is
            # off-center for asymmetric nuclei). Uses the same pixel-center
            # convention as the plot -- NOTE(review): confirm this matches the
            # consumer's annotation coordinate convention.
            "center": [(min_col + max_col - 1) / 2,
                       (min_row + max_row - 1) / 2, 0],
            "width": width,
            "height": height,
        })

        plt.plot(centroid_col, centroid_row, 'g+')
        plt.gca().add_patch(mpatches.Rectangle(
            (min_col - 0.5, min_row - 0.5), width, height,
            fill=False, ec='g', linewidth=2))

    return bboxes


long_nuclei_bboxes = _plot_nuclei_overlays(
    long_im_input, long_im_nuclei_seg_mask, LongObjProps)

trans_nuclei_bboxes = _plot_nuclei_overlays(
    trans_im_input, trans_im_nuclei_seg_mask, TransObjProps)
# Tabulate label, orientation and fitted-ellipse axis lengths for every
# nucleus. The result maps each property name to a column of per-nucleus values.
nuclei_table_properties = (
    'label',
    'orientation',
    'major_axis_length',
    'minor_axis_length',
)

LongObjPropsTable = skimage.measure.regionprops_table(
    long_im_nuclei_seg_mask, properties=nuclei_table_properties)

TransObjPropsTable = skimage.measure.regionprops_table(
    trans_im_nuclei_seg_mask, properties=nuclei_table_properties)
# Orientation of each nucleus's fitted ellipse. skimage reports it in radians,
# as the angle between the row axis and the major axis, in [-pi/2, pi/2].
long_orientation = LongObjPropsTable['orientation']
trans_orientation = TransObjPropsTable['orientation']

# Identical bin edges spanning the full orientation range, so the two
# histograms are directly comparable. Default bins adapt to each dataset's own
# min/max.
orientation_bins = np.linspace(-np.pi / 2, np.pi / 2, 19)  # 10-degree bins

for _orientation, _label in ((long_orientation, 'Longitudinal'),
                             (trans_orientation, 'Transverse')):
    # A fresh figure, so the histogram is not drawn onto the last image axes
    # when run as a script.
    plt.figure()
    _ = plt.hist(_orientation, bins=orientation_bins)
    plt.xlabel('Nucleus orientation (radians)')
    plt.ylabel('Number of nuclei')
    plt.title(f'{_label} nuclei orientation', fontsize=titlesize)
    plt.show()
# Minor-axis length (pixels) of each longitudinal nucleus's fitted ellipse.
# A bare expression only displays inside a notebook, so print it explicitly
# to make it visible when run as a script too.
print(LongObjPropsTable['minor_axis_length'])
# Recorded notebook output:
# array([17.82988548, 26.57283721, 32.62511918, 17.9561122 , 24.04859764,
#        25.90507908, 40.4939446 , 35.97834554, 22.36994041, 14.45953977,
#        40.12733974, 21.8516542 , 12.42029426, 23.54320991, 20.64609371,
#        24.50122929, 31.37118343, 23.3645544 , 18.90967646, 28.31049082,
#        27.86635531, 15.762809  , 32.84102401, 31.56465618, 18.39091073,
#        39.147678  , 24.17782857, 20.93425473, 24.76587761, 34.92536675,
#        33.35863685, 23.92101848, 25.98666127, 30.92181968, 17.00492511,
#        27.91541626, 44.926728  , 16.78490142, 54.96896025, 28.69640483,
#        20.71021864, 16.72920149, 18.5068537 ])